{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_5_FCN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "uSGZ6cdmpknm",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "\n",
        "# snnTorch - Training Spiking Convolutional Neural Networks with snnTorch\n",
        "## Tutorial 5\n",
        "### By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_6_CNN.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rugeYYiqsrlc"
      },
      "source": [
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ymi3sqJg28OQ"
      },
      "source": [
        "# Introduction\n",
        "In this tutorial, you will:\n",
        "* Learn how spiking neurons are implemented as a recurrent network\n",
        "* Understand backpropagation through time, and the associated challenges in SNNs such as the non-differentiability of spikes\n",
        "* Train a fully-connected network on the static MNIST dataset\n",
        "\n",
        "<!-- * Implement various backprop strategies:\n",
        "  * Backpropagation Through Time\n",
        "  * Truncated-Backpropagation Through Time\n",
        "  * Real-Time Recurrent Learning -->\n",
        "\n",
        ">Part of this tutorial was inspired by Friedemann Zenke's extensive work on SNNs. Check out his repo on surrogate gradients [here](https://github.com/fzenke/spytorch), and a favourite paper of mine: E. O. Neftci, H. Mostafa, F. Zenke, [Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-based optimization to spiking neural networks.](https://ieeexplore.ieee.org/document/8891809) IEEE Signal Processing Magazine 36, 51–63.\n",
        "\n",
        "At the end of the tutorial, a basic supervised learning algorithm will be implemented. We will use the original static MNIST dataset and train a multi-layer fully-connected spiking neural network using gradient descent to perform image classification. \n",
        "\n",
        "If running in Google Colab:\n",
        "* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`\n",
        "* Next, install the latest PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5tn_wUlopkon",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "!pip install snntorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QXZ6Tuqc9Q-l"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "import snntorch as snn\n",
        "from snntorch import spikeplot as splt\n",
        "from snntorch import spikegen\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import itertools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gt2xMbLY9dVE"
      },
      "source": [
        "# 1. A Recurrent Representation of SNNs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v7haBG7nA_TC"
      },
      "source": [
        "In Tutorial 3, we derived a recursive representation of a leaky integrate-and-fire (LIF) neuron: \n",
        "\n",
        "$$U[t+1] = \\underbrace{\\beta U[t]}_\\text{decay} + \\underbrace{WX[t+1]}_\\text{input} - \\underbrace{R[t]}_\\text{reset} \\tag{1}$$\n",
        "\n",
        "where the input synaptic current is interpreted as $I_{\\rm in}[t] = WX[t]$, and $X[t]$ may be some arbitrary input of spikes, a step/time-varying voltage, or unweighted step/time-varying current. Spiking is represented with the following equation, where if the membrane potential exceeds the threshold, a spike is emitted:\n",
        "\n",
        "$$S[t] = \\begin{cases} 1, &\\text{if}~U[t] > U_{\\rm thr} \\\\\n",
        "0, &\\text{otherwise}\\end{cases} \\tag{2}$$\n",
        "\n",
        "This formulation of a spiking neuron in a discrete, recursive form is almost perfectly poised to take advantage of the developments in training recurrent neural networks (RNNs) and sequence-based models. This is illustrated using an *implicit* recurrent connection for the decay of the membrane potential, and is distinguished from *explicit* recurrence where the output spike $S_{\\rm out}$ is fed back to the input. In the figure below, the connection weighted by $-U_{\\rm thr}$ represents the reset mechanism $R[t]$.\n",
        "\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/unrolled_2.png?raw=true' width=\"800\">\n",
        "</center>\n",
        "\n",
        "The benefit of an unrolled graph is that it provides an explicit description of how computations are performed. The process of unfolding illustrates the flow of information forward in time (from left to right) to compute outputs and losses, and backward in time to compute gradients. The more time steps that are simulated, the deeper the graph becomes. \n",
        "\n",
        "Conventional RNNs treat $\\beta$ as a learnable parameter. This is also possible for SNNs, though by default, they are treated as hyperparameters. This replaces the vanishing and exploding gradient problems with a hyperparameter search. A future tutorial will describe how to make $\\beta$ a learnable parameter."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wgzf83HE2BeB"
      },
      "source": [
        "# 2. The Non-Differentiability of Spikes\n",
        "## 2.1 Training Using the Backprop Algorithm\n",
        "\n",
        "An alternative way to represent the relationship between $S$ and $U$ in $(2)$ is:\n",
        "\n",
        "$$S[t] = \\Theta(U[t] - U_{\\rm thr}) \\tag{3}$$ \n",
        "\n",
        "where $\\Theta(\\cdot)$ is the Heaviside step function:\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png?raw=true' width=\"600\">\n",
        "</center>\n",
        "\n",
        "Training a network in this form poses some serious challenges. Consider a single, isolated time step of the computational graph from the previous figure titled *\"Recurrent representation of spiking neurons\"*, as shown in the *forward pass* below:\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/non-diff.png?raw=true' width=\"400\">\n",
        "</center>\n",
        "\n",
        "The goal is to train the network using the gradient of the loss with respect to the weights, such that the weights are updated to minimize the loss. The backpropagation algorithm achieves this using the chain rule:\n",
        "\n",
        "$$\\frac{\\partial \\mathcal{L}}{\\partial W} = \n",
        "\\frac{\\partial \\mathcal{L}}{\\partial S}\n",
        "\\underbrace{\\frac{\\partial S}{\\partial U}}_{\\{0, \\infty\\}}\n",
        "\\frac{\\partial U}{\\partial I}\\\n",
        "\\frac{\\partial I}{\\partial W}\\ \\tag{4}$$\n",
        "\n",
        "From $(1)$, $\\partial I/\\partial W=X$, and $\\partial U/\\partial I=1$. While we have not yet defined a loss function, we can assume $\\partial \\mathcal{L}/\\partial S$ has an analytical solution, in a similar form to the cross-entropy or mean-square error loss (more on that shortly). \n",
        "\n",
        "However, the term that we are going to grapple with is $\\partial S/\\partial U$. The derivative of the Heaviside step function from $(3)$ is the Dirac Delta function, which evaluates to 0 everywhere, except at the threshold $U_{\\rm thr} = \\theta$, where it tends to infinity. This means the gradient will almost always be nulled to zero (or saturated if $U$ sits precisely at the threshold), and no learning can take place. This is known as the **dead neuron problem**. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mVrM7nLOMvgx"
      },
      "source": [
        "## 2.2 Overcoming the Dead Neuron Problem\n",
        "\n",
        "The most common way to address the dead neuron problem is to keep the Heaviside function as it is during the forward pass, but swap the derivative term $\\partial S/\\partial U$ for something that does not kill the learning process during the backward pass, which will be denoted $\\partial \\tilde{S}/\\partial U$. This might sound odd, but it turns out that neural networks are quite robust to such approximations. This is commonly known as the *surrogate gradient* approach.\n",
        "\n",
        "A variety of options exist to using surrogate gradients, and we will dive into more detail on these methods in [Tutorial 6](https://snntorch.readthedocs.io/en/latest/tutorials/index.html). For now, a simple approximation is applied where $\\partial \\tilde{S}/\\partial U$ is set to $S$ itself.\n",
        "\n",
        "If $S$ does not spike, then the spike-gradient term is $0$. If $S$ spikes, then the gradient term is $1$. This simply looks like the gradient of a ReLU function shifted to the threshold. This method is known as the *Spike-Operator* approach and is described in more detail in the following paper:\n",
        "\n",
 <cite> Jason K. Eshraghian, Max Ward, Xinxin">
        "> <cite> Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv, 2021. </cite>\n",
        "\n",
        "Inutitively, *Spike Operator* splits the gradient calculation into two chunks: one where the neuron is spiking, and one where it is silent:\n",
        "* **Silent:** If the neuron is silent, then the spike response can be obtained by scaling the membrane by 0: $S = U \\times 0 \\implies \\partial \\tilde{S}/\\partial U = 0$. \n",
        "* **Spiking:** If the neuron is spiking, then assume $U \\approx U_{\\rm thr}$, normalize $U_{\\rm thr}=1$, and the spike response can be obtained by scaling the membrane by 1: $S = U \\times 1 \\implies \\partial \\tilde{S}/\\partial U = 1$, where the tilde above $\\tilde{S}$ implies an approximation.\n",
        "\n",
        "This is summarized as follows:\n",
        "\n",
        "$$\\frac{\\partial \\tilde{S}}{\\partial U} \\leftarrow S = \\begin{cases} 1, &\\text{if}~U> U_{\\rm thr} \\\\\n",
        "0, &\\text{otherwise}\\end{cases} $$\n",
        "\n",
        "where the left arrow denotes substitution. \n",
        "\n",
        "The same neuron model described in  $(1)-(2)$ (a.k.a., `snn.Leaky` neuron from Tutorial 3) is implemented in PyTorch below. Don't worry if you don't understand this. This will be condensed into one line of code using snnTorch in a moment:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mfJUm-6T8aG2"
      },
      "outputs": [],
      "source": [
        "# Leaky neuron model, overriding the backward pass with a custom function\n",
        "class LeakySurrogate(nn.Module):\n",
        "  def __init__(self, beta, threshold=1.0):\n",
        "      super(LeakySurrogate, self).__init__()\n",
        "\n",
        "      # initialize decay rate beta and threshold\n",
        "      self.beta = beta\n",
        "      self.threshold = threshold\n",
        "      self.spike_op = self.SpikeOperator.apply\n",
        "  \n",
        "  # the forward function is called each time we call Leaky\n",
        "  def forward(self, input_, mem):\n",
        "    spk = self.spike_op((mem-self.threshold))  # call the Heaviside function\n",
        "    reset = (spk * self.threshold).detach() # removes spike_op gradient from reset\n",
        "    mem = self.beta * mem + input_ - reset # Eq (1)\n",
        "    return spk, mem\n",
        "\n",
        "  # Forward pass: Heaviside function\n",
        "  # Backward pass: Override Dirac Delta with the Spike itself\n",
        "  @staticmethod\n",
        "  class SpikeOperator(torch.autograd.Function):\n",
        "      @staticmethod\n",
        "      def forward(ctx, mem):\n",
        "          spk = (mem > 0).float() # Heaviside on the forward pass: Eq(2)\n",
        "          ctx.save_for_backward(spk)  # store the spike for use in the backward pass\n",
        "          return spk\n",
        "\n",
        "      @staticmethod\n",
        "      def backward(ctx, grad_output):\n",
        "          (spk,) = ctx.saved_tensors  # retrieve the spike \n",
        "          grad = grad_output * spk # scale the gradient by the spike: 1/0\n",
        "          return grad"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1zkc1Mmp97OX"
      },
      "source": [
        "Note that the reset mechanism is detached from the computational graph, as the surrogate gradient should only be applied to $\\partial S/\\partial U$, and not $\\partial R/\\partial U$.\n",
        "\n",
        "The above neuron is instantiated using:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EV3lU6soOnW6"
      },
      "outputs": [],
      "source": [
        "lif1 = LeakySurrogate(beta=0.9)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "StklvL_gPns1"
      },
      "source": [
        "This neuron can be simulated using a for-loop, just as in previous tutorials, while PyTorch's automatic differentation (autodiff) mechanism keeps track of the gradient in the background.\n",
        "\n",
        "Alternatively, the same thing can be accomplished by calling the `snn.Leaky` neuron. \n",
        "In fact, every time you call any neuron model from snnTorch, the *Spike Operator* surrogate gradient is applied to it by default:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Wa7N31mP9Va"
      },
      "outputs": [],
      "source": [
        "lif1 = snn.Leaky(beta=0.9)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9EF70Xi1RX6w"
      },
      "source": [
        "If you would like to explore how this neuron behaves, then refer to [Tutorial 3](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxl1UYSCRzzl"
      },
      "source": [
        "#3. Backprop Through Time\n",
        "Equation $(4)$ only calculates the gradient for one single time step (referred to as the *immediate influence* in the figure below), but the backpropagation through time (BPTT) algorithm calculates the gradient from the loss to *all* descendants and sums them together. \n",
        "\n",
        "The weight $W$ is applied at every time step, and so imagine a loss is also calculated at every time step. The influence of the weight on present and historical losses must be summed together to define the global gradient:\n",
        "\n",
        "$$\\frac{\\partial \\mathcal{L}}{\\partial W}=\\sum_t \\frac{\\partial\\mathcal{L}[t]}{\\partial W} = \n",
        "\\sum_t \\sum_{s\\leq t} \\frac{\\partial\\mathcal{L}[t]}{\\partial W[s]}\\frac{\\partial W[s]}{\\partial W} \\tag{5} $$\n",
        "\n",
        "The point of $(5)$ is to ensure causality: by constraining $s\\leq t$, we only account for the contribution of immediate and prior influences of $W$ on the loss. A recurrent system constrains the weight to be shared across all steps: $W[0]=W[1] =~... ~ = W$. Therefore, a change in $W[s]$ will have the same effect on all $W$, which implies that $\\partial W[s]/\\partial W=1$:\n",
        "\n",
        "$$\\frac{\\partial \\mathcal{L}}{\\partial W}=\n",
        "\\sum_t \\sum_{s\\leq t} \\frac{\\partial\\mathcal{L}[t]}{\\partial W[s]} \\tag{6} $$\n",
        "\n",
        "As an example, isolate the prior influence due to $s = t-1$ *only*; this means the backward pass must track back in time by one step. The influence of $W[t-1]$ on the loss can be written as:\n",
        "\n",
        "$$\\frac{\\partial \\mathcal{L}[t]}{\\partial W[t-1]} = \n",
        "\\frac{\\partial \\mathcal{L}[t]}{\\partial S[t]}\n",
        "\\underbrace{\\frac{\\partial \\tilde{S}[t]}{\\partial U[t]}}_{S[t]}\n",
        "\\underbrace{\\frac{\\partial U[t]}{\\partial U[t-1]}}_\\beta\n",
        "\\underbrace{\\frac{\\partial U[t-1]}{\\partial I[t-1]}}_1\n",
        "\\underbrace{\\frac{\\partial I[t-1]}{\\partial W[t-1]}}_{X[t]} \\tag{7}$$\n",
        "\n",
        "We have already dealt with all of these terms from $(4)$, except for $\\partial U[t]/\\partial U[t-1]$. From $(1)$, this temporal derivative term simply evaluates to $\\beta$. So if we really wanted to, we now know enough to painstakingly calculate the derivative of every weight at every time step by hand, and it'd look something like this for a single neuron:\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial5/bptt.png?raw=true' width=\"600\">\n",
        "</center>\n",
        "\n",
        "But thankfully, PyTorch's autodiff takes care of that in the background for us.  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c_42-CbsZ1FM"
      },
      "source": [
        "# 4. Setting up the Loss / Output Decoding\n",
        "In a conventional, non-spiking neural network, a supervised, multi-class classification problem takes the neuron with the highest activation and treats that as the predicted class. \n",
        "\n",
        "In a spiking neural net, there are several options to interpreting the output spikes. The most common approaches are:\n",
        "* **Rate coding:** Take the neuron with the highest firing rate (or spike count) as the predicted class\n",
        "* **Latency coding:** Take the neuron that fires *first* as the predicted class\n",
        "\n",
        "This might feel familiar to [Tutorial 1 on neural encoding](https://snntorch.readthedocs.io/en/latest/tutorials/index.html). The difference is that, here, we are interpreting (decoding) the output spikes, rather than encoding/converting raw input data into spikes.\n",
        "\n",
        "Let's focus on a rate code. When input data is passed to the network, we want the correct neuron class to emit the most spikes over the course of the simulation run. This then corresponds to the highest average firing frequency. One way to achieve this is to increase the membrane potential of the correct class to $U>U_{\\rm thr}$, and that of incorrect classes to $U<U_{\\rm thr}$. Applying the target to $U$ serves as a proxy for modulating spiking behavior from $S$.\n",
        "\n",
        "This can be implemented by taking the softmax of the membrane potential for output neurons, where $C$ is the number of output classes:\n",
        "\n",
        "$$p_i[t] = \\frac{e^{U_i[t]}}{\\sum_{i=0}^{C}e^{U_i[t]}} \\tag{8}$$\n",
        "\n",
        "The cross-entropy between $p_i$ and the target $y_i \\in \\{0,1\\}^C$, which is a one-hot target vector, is obtained using:\n",
        "\n",
        "$$\\mathcal{L}_{CE}[t] = \\sum_{i=0}^Cy_i{\\rm log}(p_i[t]) \\tag{9}$$\n",
        "\n",
        "The practical effect is that the membrane potential of the correct class is encouraged to increase while those of incorrect classes are reduced. In effect, this means the correct class is encouraged to fire at all time steps, while incorrect classes are suppressed at all steps. This may not be the most efficient implementation of an SNN, but it is among the simplest.\n",
        "\n",
        "This target is applied at every time step of the simulation, thus also generating a loss at every step. These losses are then summed together at the end of the simulation:\n",
        "\n",
        "$$\\mathcal{L}_{CE} = \\sum_t\\mathcal{L}_{CE}[t] \\tag{10}$$\n",
        "\n",
        "This is just one of many possible ways to apply a loss function to a spiking neural network. A variety of approaches are available to use in snnTorch (in the module `snn.functional`), and will be the subject of a future tutorial.\n",
        "\n",
        "With all of the background theory having been taken care of, let’s finally dive into\n",
        "training a fully-connected spiking neural net."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zqJdfllYbc16"
      },
      "source": [
        "# 5. Setting up the Static MNIST Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lI0GbgLgpkos",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# dataloader arguments\n",
        "batch_size = 128\n",
        "data_path='/data/mnist'\n",
        "\n",
        "dtype = torch.float\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2fhRixcspkot",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Define a transform\n",
        "transform = transforms.Compose([\n",
        "            transforms.Resize((28, 28)),\n",
        "            transforms.Grayscale(),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0,), (1,))])\n",
        "\n",
        "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
        "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RAM_dP887uTq"
      },
      "source": [
        "If the above code blocks throws an error, e.g. the MNIST servers are down, then uncomment the following code instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4jyJVqUNdXDo"
      },
      "outputs": [],
      "source": [
        "# # temporary dataloader if MNIST service is unavailable\n",
        "# !wget www.di.ens.fr/~lelarge/MNIST.tar.gz\n",
        "# !tar -zxvf MNIST.tar.gz\n",
        "\n",
        "# mnist_train = datasets.MNIST(root = './', train=True, download=True, transform=transform)\n",
        "# mnist_test = datasets.MNIST(root = './', train=False, download=True, transform=transform)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aEtCbO6upkou",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Create DataLoaders\n",
        "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GhFyzySNeT_e"
      },
      "source": [
        "# 6. Define the Network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lud3kywn55fj"
      },
      "outputs": [],
      "source": [
        "# Network Architecture\n",
        "num_inputs = 28*28\n",
        "num_hidden = 1000\n",
        "num_outputs = 10\n",
        "\n",
        "# Temporal Dynamics\n",
        "num_steps = 25\n",
        "beta = 0.95"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-uquHLLmpkox",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Define Network\n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        # Initialize layers\n",
        "        self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
        "        self.lif1 = snn.Leaky(beta=beta)\n",
        "        self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
        "        self.lif2 = snn.Leaky(beta=beta)\n",
        "\n",
        "    def forward(self, x):\n",
        "\n",
        "        # Initialize hidden states at t=0\n",
        "        mem1 = self.lif1.init_leaky()\n",
        "        mem2 = self.lif2.init_leaky()\n",
        "        \n",
        "        # Record the final layer\n",
        "        spk2_rec = []\n",
        "        mem2_rec = []\n",
        "\n",
        "        for step in range(num_steps):\n",
        "            cur1 = self.fc1(x)\n",
        "            spk1, mem1 = self.lif1(cur1, mem1)\n",
        "            cur2 = self.fc2(spk1)\n",
        "            spk2, mem2 = self.lif2(cur2, mem2)\n",
        "            spk2_rec.append(spk2)\n",
        "            mem2_rec.append(mem2)\n",
        "\n",
        "        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n",
        "        \n",
        "# Load the network onto CUDA if available\n",
        "net = Net().to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y0fHcAKfrav6"
      },
      "source": [
        "The code in the `forward()` function will only be called once the input argument `x` is explicitly passed into `net`.\n",
        "\n",
        "* `fc1` applies a linear transformation to all input pixels from the MNIST dataset;\n",
        "* `lif1` integrates the weighted input over time, emitting a spike if the threshold condition is met;\n",
        "* `fc2` applies a linear transformation to the output spikes of `lif1`;\n",
        "* `lif2` is another spiking neuron layer, integrating the weighted spikes over time."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6a7MdORCtIx4"
      },
      "source": [
        "# 7. Training the SNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6D-fhT3Q7nXM"
      },
      "source": [
        "## 7.1 Accuracy Metric\n",
        "Below is a function that takes a batch of data, counts up all the spikes from each neuron (i.e., a rate code over the simulation time), and compares the index of the highest count with the actual target. If they match, then the network correctly predicted the target."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-IxcnBAxpkoy",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# pass data into the network, sum the spikes over time\n",
        "# and compare the neuron with the highest number of spikes\n",
        "# with the target\n",
        "\n",
        "def print_batch_accuracy(data, targets, train=False):\n",
        "    output, _ = net(data.view(batch_size, -1))\n",
        "    _, idx = output.sum(dim=0).max(1)\n",
        "    acc = np.mean((targets == idx).detach().cpu().numpy())\n",
        "\n",
        "    if train:\n",
        "        print(f\"Train set accuracy for a single minibatch: {acc*100:.2f}%\")\n",
        "    else:\n",
        "        print(f\"Test set accuracy for a single minibatch: {acc*100:.2f}%\")\n",
        "\n",
        "def train_printer():\n",
        "    print(f\"Epoch {epoch}, Iteration {iter_counter}\")\n",
        "    print(f\"Train Set Loss: {loss_hist[counter]:.2f}\")\n",
        "    print(f\"Test Set Loss: {test_loss_hist[counter]:.2f}\")\n",
        "    print_batch_accuracy(data, targets, train=True)\n",
        "    print_batch_accuracy(test_data, test_targets, train=False)\n",
        "    print(\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "woJSGSx68tsd"
      },
      "source": [
        "## 7.2 Loss Definition\n",
        "The `nn.CrossEntropyLoss` function in PyTorch automatically handles taking the softmax of the output layer as well as generating a loss at the output. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iqdVyjCNtdlp"
      },
      "outputs": [],
      "source": [
        "loss = nn.CrossEntropyLoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b1fPgSoO9Jgb"
      },
      "source": [
        "## 7.3 Optimizer\n",
        "Adam is a robust optimizer that performs well on recurrent networks, so let's use that with a learning rate of $5\\times10^{-4}$. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l62ZR51s9Lxg"
      },
      "outputs": [],
      "source": [
        "optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GiqAVKzVbfPn"
      },
      "source": [
        "## 7.4 One Iteration of Training\n",
        "Take the first batch of data and load it onto CUDA if available."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hv1q2-Mt9kVi"
      },
      "outputs": [],
      "source": [
        "data, targets = next(iter(train_loader))\n",
        "data = data.to(device)\n",
        "targets = targets.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cFahTbAv-Vtt"
      },
      "source": [
        "Flatten the input data to a vector of size $784$ and pass it into the network."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lltqTEXE92V-"
      },
      "outputs": [],
      "source": [
        "spk_rec, mem_rec = net(data.view(batch_size, -1))\n",
        "print(mem_rec.size())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wd_qv6xD-lCb"
      },
      "source": [
        "The recording of the membrane potential is taken across:\n",
        "* 25 time steps\n",
        "* 128 samples of data\n",
        "* 10 output neurons\n",
        "\n",
        "We wish to calculate the loss at every time step and sum these together, as per Equation $(10)$:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nsnH8y5G-D-z"
      },
      "outputs": [],
      "source": [
        "# initialize the total loss value\n",
        "loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
        "\n",
        "# sum loss at every step\n",
        "for step in range(num_steps):\n",
        "  loss_val += loss(mem_rec[step], targets)\n",
        "\n",
        "print(f\"Training loss: {loss_val.item():.3f}\")"
      ]
    },
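    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `nn.CrossEntropyLoss` averages over the batch by default, the summed loss above can also be written without a Python loop. The sketch below checks this on random stand-in tensors; `mem`, `labels`, `loss_fn` and the shapes are assumptions that mirror the recording described above.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "torch.manual_seed(0)\n",
        "num_steps, batch, num_outputs = 25, 128, 10\n",
        "\n",
        "# random stand-ins for the membrane potential recording and the targets\n",
        "mem = torch.randn(num_steps, batch, num_outputs)\n",
        "labels = torch.randint(0, num_outputs, (batch,))\n",
        "loss_fn = nn.CrossEntropyLoss()\n",
        "\n",
        "# loop form: one batch-averaged loss per time step, summed over time\n",
        "loop_loss = sum(loss_fn(mem[step], labels) for step in range(num_steps))\n",
        "\n",
        "# vectorized form: flatten time into the batch dimension, then rescale\n",
        "flat_loss = num_steps * loss_fn(mem.reshape(-1, num_outputs), labels.repeat(num_steps))\n",
        "\n",
        "print(torch.allclose(loop_loss, flat_loss, atol=1e-4))\n",
        "```\n",
        "\n",
        "The loop form is kept in this tutorial because it maps directly onto the per-step sum in the equation."
      ]
    },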
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q4r0sKMV_4ri"
      },
      "source": [
        "The loss is quite large because it is summed over 25 time steps. The accuracy is also poor (it should be around 10%, i.e., chance level for 10 classes), as the network is untrained:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qetPvz7mAArd"
      },
      "outputs": [],
      "source": [
        "print_batch_accuracy(data, targets, train=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fUcR0GcUAtPn"
      },
      "source": [
        "A single weight update is applied to the network as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WxyBhsmlAsWM"
      },
      "outputs": [],
      "source": [
        "# clear previously stored gradients\n",
        "optimizer.zero_grad()\n",
        "\n",
        "# calculate the gradients\n",
        "loss_val.backward()\n",
        "\n",
        "# weight update\n",
        "optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ubyude8eA5p9"
      },
      "source": [
        "Now, re-run the loss calculation and accuracy after a single iteration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l4ZquRR9A9He"
      },
      "outputs": [],
      "source": [
        "# calculate new network outputs using the same data\n",
        "spk_rec, mem_rec = net(data.view(batch_size, -1))\n",
        "\n",
        "# initialize the total loss value\n",
        "loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
        "\n",
        "# sum loss at every step\n",
        "for step in range(num_steps):\n",
        "  loss_val += loss(mem_rec[step], targets)\n",
        "\n",
        "print(f\"Training loss: {loss_val.item():.3f}\")\n",
        "print_batch_accuracy(data, targets, train=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbNPCNTSBaW3"
      },
      "source": [
        "After only one iteration, the loss should have decreased and the accuracy should have increased. Note how the membrane potential is used to calculate the cross entropy loss, while the spike count is used to measure accuracy. It is also possible to use the spike count in the loss ([see Tutorial 6](https://snntorch.readthedocs.io/en/latest/tutorials/index.html))."
      ]
    },
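    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accuracy measure can be sketched as follows: sum the output spikes over time and pick the neuron with the highest count. The tiny hand-made spike tensor `spk` and the labels `labels` are assumptions for illustration; the reduction itself is the same one used for the test set accuracy later in this tutorial.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# spikes with shape [time steps, batch, output neurons] = [3, 2, 4]\n",
        "spk = torch.tensor([\n",
        "    [[0, 1, 0, 0], [0, 0, 0, 1]],\n",
        "    [[0, 1, 0, 0], [1, 0, 0, 1]],\n",
        "    [[1, 1, 0, 0], [0, 0, 1, 1]],\n",
        "], dtype=torch.float)\n",
        "labels = torch.tensor([1, 2])\n",
        "\n",
        "# count spikes per neuron over time, then take the most active neuron\n",
        "spike_count = spk.sum(dim=0)\n",
        "_, predicted = spike_count.max(1)\n",
        "\n",
        "accuracy = (predicted == labels).float().mean().item()\n",
        "print(predicted.tolist(), accuracy)\n",
        "```\n",
        "\n",
        "Here the first sample is classified correctly (neuron 1 fires three times) and the second is not (neuron 3 fires most, but the label is 2), giving an accuracy of 0.5."
      ]
    },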
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mVgKDes8BiXq"
      },
      "source": [
        "## 7.5 Training Loop\n",
        "\n",
        "Let's combine everything into a training loop. We will train for one epoch (though feel free to increase `num_epochs`), exposing our network to each sample of data once."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LMZMxEV8dcTC"
      },
      "outputs": [],
      "source": [
        "num_epochs = 1\n",
        "loss_hist = []\n",
        "test_loss_hist = []\n",
        "counter = 0\n",
        "\n",
        "# Outer training loop\n",
        "for epoch in range(num_epochs):\n",
        "    iter_counter = 0\n",
        "    train_batch = iter(train_loader)\n",
        "\n",
        "    # Minibatch training loop\n",
        "    for data, targets in train_batch:\n",
        "        data = data.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        # forward pass\n",
        "        net.train()\n",
        "        spk_rec, mem_rec = net(data.view(batch_size, -1))\n",
        "\n",
        "        # initialize the loss & sum over time\n",
        "        loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
        "        for step in range(num_steps):\n",
        "            loss_val += loss(mem_rec[step], targets)\n",
        "\n",
        "        # Gradient calculation + weight update\n",
        "        optimizer.zero_grad()\n",
        "        loss_val.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Store loss history for future plotting\n",
        "        loss_hist.append(loss_val.item())\n",
        "\n",
        "        # Test set\n",
        "        with torch.no_grad():\n",
        "            net.eval()\n",
        "            test_data, test_targets = next(iter(test_loader))\n",
        "            test_data = test_data.to(device)\n",
        "            test_targets = test_targets.to(device)\n",
        "\n",
        "            # Test set forward pass\n",
        "            test_spk, test_mem = net(test_data.view(batch_size, -1))\n",
        "\n",
        "            # Test set loss\n",
        "            test_loss = torch.zeros((1), dtype=dtype, device=device)\n",
        "            for step in range(num_steps):\n",
        "                test_loss += loss(test_mem[step], test_targets)\n",
        "            test_loss_hist.append(test_loss.item())\n",
        "\n",
        "            # Print train/test loss/accuracy\n",
        "            if counter % 50 == 0:\n",
        "                train_printer()\n",
        "            counter += 1\n",
        "            iter_counter += 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Taf6WZLojHTz"
      },
      "source": [
        "If this is your first time training an SNN, then congratulations!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "HxU7P7xFpko3",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# 8. Results\n",
        "## 8.1 Plot Training/Test Loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Pk_EScnpkpj",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Plot Loss\n",
        "fig = plt.figure(facecolor=\"w\", figsize=(10, 5))\n",
        "plt.plot(loss_hist)\n",
        "plt.plot(test_loss_hist)\n",
        "plt.title(\"Loss Curves\")\n",
        "plt.legend([\"Train Loss\", \"Test Loss\"])\n",
        "plt.xlabel(\"Iteration\")\n",
        "plt.ylabel(\"Loss\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g-Gd84OAl1rB"
      },
      "source": [
        "The loss curves are noisy because the losses are tracked at every iteration rather than averaged across multiple iterations."
      ]
    },
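    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to smooth such curves is a simple moving average over a window of iterations. This is a sketch, not part of the tutorial code: `moving_average`, the window size, and the toy `noisy` list are assumptions; in the notebook one could pass `loss_hist` instead.\n",
        "\n",
        "```python\n",
        "def moving_average(values, window=3):\n",
        "    # hypothetical helper: mean of each run of `window` consecutive values\n",
        "    return [\n",
        "        sum(values[i:i + window]) / window\n",
        "        for i in range(len(values) - window + 1)\n",
        "    ]\n",
        "\n",
        "noisy = [4.0, 2.0, 3.0, 1.0, 2.0]\n",
        "smooth = moving_average(noisy, window=3)\n",
        "print(smooth)\n",
        "```\n",
        "\n",
        "A larger window gives a smoother curve at the cost of hiding short-lived changes in the loss."
      ]
    },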
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "Z3f0vBnBpkpk",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## 8.2 Test Set Accuracy\n",
        "The following code iterates over all minibatches to measure accuracy over the full 10,000 samples in the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F5Rb4xHGndQh"
      },
      "outputs": [],
      "source": [
        "total = 0\n",
        "correct = 0\n",
        "\n",
        "# drop_last switched to False to keep all samples\n",
        "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)\n",
        "\n",
        "with torch.no_grad():\n",
        "  net.eval()\n",
        "  for data, targets in test_loader:\n",
        "    data = data.to(device)\n",
        "    targets = targets.to(device)\n",
        "    \n",
        "    # forward pass\n",
        "    test_spk, _ = net(data.view(data.size(0), -1))\n",
        "\n",
        "    # calculate total accuracy\n",
        "    _, predicted = test_spk.sum(dim=0).max(1)\n",
        "    total += targets.size(0)\n",
        "    correct += (predicted == targets).sum().item()\n",
        "\n",
        "print(f\"Total correctly classified test set images: {correct}/{total}\")\n",
        "print(f\"Test Set Accuracy: {100 * correct / total:.2f}%\")"
      ]
    },
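    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reason `data.size(0)` replaces `batch_size` in the reshape above is that, with `drop_last=False`, the final minibatch is smaller than the rest. A quick check of the arithmetic, assuming the batch size of 128 and the 10,000 test samples mentioned in this tutorial:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "num_samples = 10_000   # size of the test set\n",
        "batch_size = 128       # assumed to match the value used above\n",
        "\n",
        "num_batches = math.ceil(num_samples / batch_size)\n",
        "last_batch = num_samples - (num_batches - 1) * batch_size\n",
        "\n",
        "print(num_batches, last_batch)\n",
        "```\n",
        "\n",
        "This gives 79 minibatches, the last of which holds only 16 samples, so a reshape hard-coded to `batch_size` would fail on it."
      ]
    },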
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "TBIXau4Zpkpl",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "Voila! That's it for static MNIST. Feel free to tweak the network parameters, hyperparameters, and decay rate, or to add a learning rate scheduler, to see if you can improve the network performance."
      ]
    },
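    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As one example of a learning rate scheduler, the sketch below halves the learning rate every two epochs with `torch.optim.lr_scheduler.StepLR`. The dummy parameter, `step_size`, and `gamma` values are arbitrary assumptions, not tuned settings for this network.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# a single dummy parameter stands in for net.parameters()\n",
        "param = torch.nn.Parameter(torch.zeros(1))\n",
        "optimizer = torch.optim.Adam([param], lr=5e-4, betas=(0.9, 0.999))\n",
        "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)\n",
        "\n",
        "lrs = []\n",
        "for epoch in range(4):\n",
        "    # placeholder for one epoch of training\n",
        "    optimizer.zero_grad()\n",
        "    (param - 1.0).pow(2).sum().backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    # record the learning rate used this epoch, then advance the schedule\n",
        "    lrs.append(scheduler.get_last_lr()[0])\n",
        "    scheduler.step()\n",
        "\n",
        "print(lrs)\n",
        "```\n",
        "\n",
        "In the training loop of this tutorial, `scheduler.step()` would be called once at the end of each epoch, after the minibatch loop."
      ]
    },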
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s0dAgWUt2o6E"
      },
      "source": [
        "# Conclusion\n",
        "Now you know how to construct and train a fully-connected network on a static dataset. The spiking neurons can also be adapted to other layer types, including convolutions and skip connections. Armed with this knowledge, you should now be able to build many different types of SNNs. [In the next tutorial](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), you will learn how to train a spiking convolutional network and reduce the amount of code required by using the `snn.backprop` module.\n",
        "\n",
        "If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Additional Resources\n",
        "\n",
        "* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "include_colab_link": true,
      "name": "Copy of tutorial_2_FCN_truncatedfromscratch.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 2
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython2",
      "version": "2.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
